#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import wraps
from pathlib import PurePosixPath, PureWindowsPath
from shutil import copyfileobj
from typing import TYPE_CHECKING, Any, Literal

import smbclient

from airflow.providers.common.compat.sdk import BaseHook

if TYPE_CHECKING:
    import smbprotocol.connection


class SambaHook(BaseHook):
    """
    Allows for interaction with a Samba server.

    The hook should be used as a context manager in order to correctly
    set up a session and disconnect open connections upon exit.

    Paths passed to the file-system methods are relative to the share: they are
    joined with the connection host and the share name before being handed to
    ``smbclient``, using POSIX-style (``//host/share/path``) or Windows UNC-style
    formatting depending on the share type.

    :param samba_conn_id: The connection id reference.
    :param share:
        An optional share name. If this is unset then the "schema" field of
        the connection is used in its place.
    :param share_type:
        An optional share type name (``posix`` or ``windows``). If this is unset then
        the ``share_type`` value from the connection extras is used, defaulting to
        ``posix``. Any other value falls back to ``posix`` with a logged warning.
    """

    conn_name_attr = "samba_conn_id"
    default_conn_name = "samba_default"
    conn_type = "samba"
    hook_name = "Samba"

    def __init__(
        self,
        samba_conn_id: str = default_conn_name,
        share: str | None = None,
        share_type: Literal["posix", "windows"] | None = None,
    ) -> None:
        super().__init__()
        conn = self.get_connection(samba_conn_id)

        if not conn.login:
            self.log.info("Login not provided")

        if not conn.password:
            self.log.info("Password not provided")

        # An explicit argument wins over the connection extra; unsupported values
        # degrade to "posix" instead of failing construction.
        self._share_type = share_type or conn.extra_dejson.get("share_type", "posix")
        if self._share_type not in {"posix", "windows"}:
            self._share_type = "posix"
            self.log.warning(
                "Invalid share_type specified. It must be either 'posix' or 'windows'. Falling back to 'posix'"
            )

        # A fresh cache per hook instance. It is passed to every smbclient call,
        # and __exit__ disconnects whatever connections end up in it.
        connection_cache: dict[str, smbprotocol.connection.Connection] = {}

        self._host = conn.host
        self._share = share or conn.schema
        self._connection_cache = connection_cache
        # Keyword arguments forwarded to every smbclient call made by this hook.
        self._conn_kwargs = {
            "username": conn.login,
            "password": conn.password,
            "port": conn.port or 445,
            "connection_cache": connection_cache,
        }

    def __enter__(self):
        # This immediately connects to the host (which can be
        # perceived as a benefit), but also helps work around an issue:
        #
        # https://github.com/jborean93/smbprotocol/issues/109.
        try:
            smbclient.register_session(self._host, **self._conn_kwargs)
        except Exception:
            # __exit__ is not called when __enter__ raises, so release any
            # connection register_session may already have put in our cache.
            # The bare ``raise`` below re-raises the original error.
            self._disconnect_all()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Disconnect every cached connection first; only afterwards surface the
        # first disconnect failure (if any), so one failure cannot leak the rest.
        error = self._disconnect_all()
        if error is not None:
            raise error

    def _disconnect_all(self) -> Exception | None:
        """
        Disconnect and forget every connection in this hook's cache.

        Every connection is attempted even if an earlier one fails. Failures are
        logged, and the first one is returned (not raised) so the caller can decide
        whether to propagate it. The cache is always cleared.

        :return: The first exception raised by a ``disconnect()`` call, or ``None``.
        """
        first_error: Exception | None = None
        try:
            # Iterate over a snapshot so the loop is unaffected by any change to
            # the dict made while disconnecting.
            for host, connection in list(self._connection_cache.items()):
                self.log.info("Disconnecting from %s", host)
                try:
                    connection.disconnect()
                except Exception as err:
                    self.log.warning("Failed to disconnect from %s: %s", host, err)
                    if first_error is None:
                        first_error = err
        finally:
            self._connection_cache.clear()
        return first_error

    @staticmethod
    def _join_posix_path(host: str, share: str, path: str) -> str:
        """Return ``//host/share/path`` for a path relative to the share."""
        # Leading "/" are stripped because joining an absolute component would make
        # pathlib discard host and share. PurePosixPath preserves exactly two
        # leading slashes, so the "//host" prefix survives.
        return str(PurePosixPath("//" + host, share, path.lstrip("/")))

    @staticmethod
    def _join_windows_path(host: str, share: str, path: str) -> str:
        """Return a Windows UNC path (two leading backslashes, then host, share, path)."""
        # The prefix is built as a single-backslash rooted path ("\host\share"), the
        # relative path is joined onto it, and the second leading backslash is
        # prepended afterwards. Leading "\" or "/" are stripped from *path* so it is
        # joined under the share instead of replacing it.
        return "\\{}".format(PureWindowsPath(f"\\{host}\\{share}", path.lstrip(r"\/")))

    def _join_path(self, path):
        """Join *path* onto this hook's host and share using the configured share type."""
        if self._share_type == "windows":
            return self._join_windows_path(self._host, self._share, path)
        return self._join_posix_path(self._host, self._share, path)

    # File-system wrappers. Each one joins its path argument(s) onto the host and
    # share, then forwards to the same-named smbclient function together with this
    # hook's connection settings. functools.wraps copies the smbclient function's
    # metadata (including __doc__) onto the wrapper, so these methods deliberately
    # carry no docstring of their own.

    @wraps(smbclient.link)
    def link(self, src, dst, follow_symlinks=True):
        return smbclient.link(
            self._join_path(src),
            self._join_path(dst),
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.listdir)
    def listdir(self, path):
        return smbclient.listdir(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.lstat)
    def lstat(self, path):
        return smbclient.lstat(self._join_path(path), **self._conn_kwargs)

    def makedirs(self, path, exist_ok=False):
        """
        Create directory *path* (relative to the share) via ``smbclient.makedirs``.

        *exist_ok* is forwarded unchanged. Only the private helper is decorated
        with ``@wraps``, so this public method keeps its own name, signature and
        docstring.
        """
        self._makedirs(path, exist_ok)

    @wraps(smbclient.makedirs)
    def _makedirs(self, path, exist_ok):
        return smbclient.makedirs(self._join_path(path), exist_ok=exist_ok, **self._conn_kwargs)

    @wraps(smbclient.mkdir)
    def mkdir(self, path):
        return smbclient.mkdir(self._join_path(path), **self._conn_kwargs)

    def open_file(
        self,
        path,
        mode="r",
        buffering=-1,
        encoding=None,
        errors=None,
        newline=None,
        share_access=None,
        desired_access=None,
        file_attributes=None,
        file_type="file",
    ):
        """
        Open *path* (relative to the share) and return the object from ``smbclient.open_file``.

        All remaining arguments are forwarded unchanged to ``smbclient.open_file``.
        Only the private helper is decorated with ``@wraps``, so this public method
        keeps its own name, signature, defaults and docstring.
        """
        return self._open_file(
            path,
            mode,
            buffering,
            encoding,
            errors,
            newline,
            share_access,
            desired_access,
            file_attributes,
            file_type,
        )

    @wraps(smbclient.open_file)
    def _open_file(
        self,
        path,
        mode,
        buffering,
        encoding,
        errors,
        newline,
        share_access,
        desired_access,
        file_attributes,
        file_type="file",
    ):
        return smbclient.open_file(
            self._join_path(path),
            mode=mode,
            buffering=buffering,
            encoding=encoding,
            errors=errors,
            newline=newline,
            share_access=share_access,
            desired_access=desired_access,
            file_attributes=file_attributes,
            file_type=file_type,
            **self._conn_kwargs,
        )

    @wraps(smbclient.readlink)
    def readlink(self, path):
        return smbclient.readlink(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.remove)
    def remove(self, path):
        return smbclient.remove(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.removedirs)
    def removedirs(self, path):
        return smbclient.removedirs(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.rename)
    def rename(self, src, dst):
        return smbclient.rename(self._join_path(src), self._join_path(dst), **self._conn_kwargs)

    @wraps(smbclient.replace)
    def replace(self, src, dst):
        return smbclient.replace(self._join_path(src), self._join_path(dst), **self._conn_kwargs)

    @wraps(smbclient.rmdir)
    def rmdir(self, path):
        return smbclient.rmdir(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.scandir)
    def scandir(self, path, search_pattern="*"):
        return smbclient.scandir(
            self._join_path(path),
            search_pattern=search_pattern,
            **self._conn_kwargs,
        )

    @wraps(smbclient.stat)
    def stat(self, path, follow_symlinks=True):
        return smbclient.stat(self._join_path(path), follow_symlinks=follow_symlinks, **self._conn_kwargs)

    @wraps(smbclient.stat_volume)
    def stat_volume(self, path):
        return smbclient.stat_volume(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.symlink)
    def symlink(self, src, dst, target_is_directory=False):
        return smbclient.symlink(
            self._join_path(src),
            self._join_path(dst),
            target_is_directory=target_is_directory,
            **self._conn_kwargs,
        )

    @wraps(smbclient.truncate)
    def truncate(self, path, length):
        return smbclient.truncate(self._join_path(path), length, **self._conn_kwargs)

    @wraps(smbclient.unlink)
    def unlink(self, path):
        return smbclient.unlink(self._join_path(path), **self._conn_kwargs)

    @wraps(smbclient.utime)
    def utime(self, path, times=None, ns=None, follow_symlinks=True):
        return smbclient.utime(
            self._join_path(path),
            times=times,
            ns=ns,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.walk)
    def walk(self, path, topdown=True, onerror=None, follow_symlinks=False):
        return smbclient.walk(
            self._join_path(path),
            topdown=topdown,
            onerror=onerror,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    @wraps(smbclient.getxattr)
    def getxattr(self, path, attribute, follow_symlinks=True):
        return smbclient.getxattr(
            self._join_path(path), attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.listxattr)
    def listxattr(self, path, follow_symlinks=True):
        return smbclient.listxattr(
            self._join_path(path), follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.removexattr)
    def removexattr(self, path, attribute, follow_symlinks=True):
        return smbclient.removexattr(
            self._join_path(path), attribute, follow_symlinks=follow_symlinks, **self._conn_kwargs
        )

    @wraps(smbclient.setxattr)
    def setxattr(self, path, attribute, value, flags=0, follow_symlinks=True):
        return smbclient.setxattr(
            self._join_path(path),
            attribute,
            value,
            flags=flags,
            follow_symlinks=follow_symlinks,
            **self._conn_kwargs,
        )

    def push_from_local(self, destination_filepath: str, local_filepath: str, buffer_size: int | None = None):
        """
        Push local file to samba server.

        :param destination_filepath: the samba location to push to
        :param local_filepath: the file to push
        :param buffer_size:
            size in bytes of the individual chunks of file to send. Larger values may
            speed up large file transfers
        """
        # Only pass a chunk size when a truthy one was given; otherwise (None or 0)
        # copyfileobj uses its own default.
        extra_args = (buffer_size,) if buffer_size else ()
        with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g:
            copyfileobj(f, g, *extra_args)

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom field behaviour."""
        return {
            "hidden_fields": [],
            "relabeling": {"schema": "Share"},
        }

    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to connection form."""
        from flask_babel import lazy_gettext
        from wtforms import StringField

        return {
            "share_type": StringField(
                label=lazy_gettext("Share Type"),
                description="The share OS type (`posix` or `windows`). Used to determine the formatting of file and folder paths.",
                default="posix",
            )
        }
